# Q-Learning vs SARSA — Random GridWorld
What we will learn from this notebook:

- How do we generate a simple grid environment? What components do we need to define for the environment?
- Q-learning and SARSA implementations
- The effect of reward shaping and having terminal non-goal states (lava)
- The effect of randomizing the initial state
- Comparing off-policy and on-policy methods
- The effect of exploration (the epsilon of the epsilon-greedy policy)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(0)


# How do we generate a simple grid environment? What components do we need to define for the environment?

In [None]:

class GridWorld:
    """Randomly generated rectangular grid world with walls, lava, a start and a goal.

    Cell codes stored in ``self.grid`` (shape ``(height, width)``):
        0 = empty, 1 = wall, 2 = lava, 3 = start, 4 = goal.

    Positions are ``(row, col)`` tuples. Actions are indices into
    ``self.actions``; ``self.action_names`` gives the matching one-letter
    names (R, D, L, U).

    Args:
        width, height: grid dimensions in cells.
        wall_prob: probability that a cell becomes a wall.
        lava_prob: probability that a non-wall cell becomes lava.
        random_seed: seed of the private ``np.random.RandomState`` used for
            grid generation and for random initial positions.
        per_step_penalty: value added to every reward except the goal reward
            (use a negative number to penalise long episodes).
        lava_terminal: whether stepping into lava ends the episode.
        random_initial: if True, ``reset()`` places the agent on a random
            free cell; otherwise on ``self.start``.

    Raises:
        ValueError: if the generated grid has fewer than two empty cells, so
            a distinct start and goal cannot be placed.
    """

    def __init__(self, width=8, height=6, wall_prob=0.15, lava_prob=0.05,
                 random_seed=None, per_step_penalty=0.0, lava_terminal=True, random_initial=True):
        self.width = width
        self.height = height
        self.rng = np.random.RandomState(random_seed)
        self.wall_prob = wall_prob
        self.lava_prob = lava_prob
        self.per_step_penalty = per_step_penalty
        self.lava_terminal = lava_terminal
        self.random_initial = random_initial
        self.reset_grid()

    def reset_grid(self):
        """Generate a new random layout and pick distinct start and goal cells.

        The agent is placed on the start cell. Nothing here checks that the
        goal is reachable from the start.
        """
        self.grid = np.zeros((self.height, self.width), dtype=int)
        for r in range(self.height):
            for c in range(self.width):
                if self.rng.rand() < self.wall_prob:
                    self.grid[r, c] = 1
                elif self.rng.rand() < self.lava_prob:
                    self.grid[r, c] = 2

        # Empty cells (code 0) are the only candidates for start, goal and
        # random initial positions. Stored as plain (row, col) int tuples so
        # they compare equal to the positions produced by step().
        free = [(int(r), int(c)) for r, c in np.argwhere(self.grid == 0)]
        if len(free) < 2:
            raise ValueError(
                "GridWorld needs at least two empty cells to place a start and a goal; "
                f"got {len(free)} (lower wall_prob/lava_prob or enlarge the grid)."
            )

        # choose start and goal distinct
        self.start = free[self.rng.randint(len(free))]
        while True:
            g = free[self.rng.randint(len(free))]
            if g != self.start:
                self.goal = g
                break

        # The goal is excluded from the initial-state candidates: an episode
        # that begins on the goal would never receive the goal reward.
        self.free_forstart = [cell for cell in free if cell != self.goal]

        self.grid[self.start] = 3
        self.grid[self.goal] = 4
        self.agent_pos = self.start
        self.actions = [(0,1),(1,0),(0,-1),(-1,0)]
        self.action_names = ['R','D','L','U']

    def in_bounds(self, pos):
        """Return True if ``pos`` = (row, col) lies inside the grid."""
        r,c = pos
        return 0 <= r < self.height and 0 <= c < self.width

    def step(self, action):
        """Apply ``action`` and return ``(agent_pos, reward, done, info)``.

        Rewards (``per_step_penalty`` is added to all but the goal reward):
            * moving off the grid: -1.0, the agent stays in place;
            * moving into a wall:  -2.0, the agent stays in place;
            * moving into lava:   -10.0, the agent moves, ``done`` equals
              ``lava_terminal``;
            * reaching the goal:  +10.0, ``done`` is True;
            * any other move:      -0.1.
        """
        ## How can we add teleport cells?
        dr,dc = self.actions[action]
        nr, nc = self.agent_pos[0]+dr, self.agent_pos[1]+dc
        if not self.in_bounds((nr,nc)):
            reward = -1.0 + self.per_step_penalty
            done = False
            return self.agent_pos, reward, done, {}
        cell = self.grid[nr, nc]
        if cell == 1:
            reward = -2.0 + self.per_step_penalty
            done = False
            return self.agent_pos, reward, done, {}
        self.agent_pos = (nr,nc)
        if cell == 2:
            reward = -10.0 + self.per_step_penalty
            done = self.lava_terminal
            return self.agent_pos, reward, done, {}
        if cell == 4:
            reward = 10.0
            done = True
            return self.agent_pos, reward, done, {}
        reward = -0.1 + self.per_step_penalty
        done = False
        return self.agent_pos, reward, done, {}

    def reset(self):
        """Start a new episode and return the agent's initial position.

        Two modes:
            1. ``random_initial = True``: the agent is placed on a random cell
               drawn from ``self.free_forstart`` using the environment's RNG.
            2. ``random_initial = False``: the agent is moved to ``self.start``.
        """
        if self.random_initial:
            choice = self.rng.randint(len(self.free_forstart))
            self.agent_pos = self.free_forstart[choice]
        else:
            self.agent_pos = self.start
        return self.agent_pos

    def render(self):
        """Print the grid as text; the agent's cell is shown as 'A'."""
        symbols = {0: '.', 1: '#', 2: 'L', 3: 'S', 4: 'G'}
        ar,ac = self.agent_pos
        for r in range(self.height):
            row = [symbols[int(self.grid[r, c])] for c in range(self.width)]
            if r == ar:
                row[ac] = 'A'
            print(' '.join(row))

    def state_to_idx(self, pos):
        """Map a (row, col) position to a flat row-major state index."""
        return pos[0]*self.width + pos[1]


# Q-learning and SARSA implementations

In [None]:

def make_epsilon_greedy_policy(Q, nA, epsilon):
    """Build an epsilon-greedy policy on top of the Q table.

    Args:
        Q: table indexed as ``Q[state]`` -> sequence of ``nA`` action values.
            The policy keeps a reference to it, so later updates to ``Q`` are
            reflected in the returned probabilities.
        nA: number of actions.
        epsilon: exploration rate in [0, 1].

    Returns:
        ``policy_fn(state)`` -> NumPy array of length ``nA`` with the
        probability of taking each action: every action gets ``epsilon / nA``
        and the greedy action gets an extra ``1 - epsilon``. Ties between
        equally valued actions are broken in favour of the lowest index.
    """
    def policy_fn(state):
        probs = np.ones(nA, dtype=float) * epsilon / nA
        best_action = int(np.argmax(Q[state]))
        probs[best_action] += 1.0 - epsilon
        return probs

    return policy_fn

def q_learning(env, num_episodes=500, alpha=0.5, gamma=0.99, epsilon=0.1, max_steps=200):
    """Tabular Q-learning (off-policy TD control) with an epsilon-greedy behaviour policy.

    Args:
        env: environment exposing ``reset()``, ``step(action)``,
            ``state_to_idx(pos)``, ``actions``, ``height`` and ``width``.
        num_episodes: number of training episodes.
        alpha: learning rate.
        gamma: discount factor.
        epsilon: exploration rate of the behaviour policy.
        max_steps: maximum number of steps per episode.

    Returns:
        ``(Q, rewards_history)`` where ``Q`` has shape
        ``(height * width, number of actions)`` and ``rewards_history`` holds
        the undiscounted return of every episode.

    Actions are sampled with ``np.random.choice``, i.e. from NumPy's global
    random state.
    """
    nA = len(env.actions)
    nS = env.height * env.width
    Q = np.zeros((nS, nA))
    # The policy holds a reference to Q, so it follows the updates below.
    policy = make_epsilon_greedy_policy(Q, nA, epsilon)
    rewards_history = []
    for i_episode in range(num_episodes):
        s = env.state_to_idx(env.reset())
        total_reward = 0.0
        for t in range(max_steps):
            probs = policy(s)
            a = np.random.choice(nA, p=probs)
            next_pos, r, done, _ = env.step(a)
            s2 = env.state_to_idx(next_pos)
            # Off-policy target: bootstrap from the best next action, and do
            # not bootstrap at all from a terminal transition.
            target = r if done else r + gamma * np.max(Q[s2])
            Q[s, a] += alpha * (target - Q[s, a])
            total_reward += r
            s = s2
            if done:
                break
        rewards_history.append(total_reward)
    return Q, rewards_history

def sarsa(env, num_episodes=500, alpha=0.5, gamma=0.99, epsilon=0.1, max_steps=200):
    """Tabular SARSA (on-policy TD control) with an epsilon-greedy policy.

    Args:
        env: environment exposing ``reset()``, ``step(action)``,
            ``state_to_idx(pos)``, ``actions``, ``height`` and ``width``.
        num_episodes: number of training episodes.
        alpha: learning rate.
        gamma: discount factor.
        epsilon: exploration rate of the policy being learned and followed.
        max_steps: maximum number of steps per episode.

    Returns:
        ``(Q, rewards_history)`` where ``Q`` has shape
        ``(height * width, number of actions)`` and ``rewards_history`` holds
        the undiscounted return of every episode.

    Actions are sampled with ``np.random.choice``, i.e. from NumPy's global
    random state.
    """
    nA = len(env.actions)
    nS = env.height * env.width
    Q = np.zeros((nS, nA))
    # The policy holds a reference to Q, so it follows the updates below.
    policy = make_epsilon_greedy_policy(Q, nA, epsilon)
    rewards_history = []
    for i_episode in range(num_episodes):
        s = env.state_to_idx(env.reset())
        probs = policy(s)
        a = np.random.choice(nA, p=probs)

        total_reward = 0.0
        for t in range(max_steps):
            next_pos, r, done, _ = env.step(a)
            s2 = env.state_to_idx(next_pos)
            # On-policy: the next action is drawn from the same policy and is
            # the one actually executed on the next iteration.
            a2 = np.random.choice(nA, p=policy(s2))
            # Do not bootstrap from a terminal transition.
            target = r if done else r + gamma * Q[s2, a2]
            Q[s, a] += alpha * (target - Q[s, a])
            total_reward += r
            s, a = s2, a2
            if done:
                break
        rewards_history.append(total_reward)
    return Q, rewards_history


# Training loop, evaluating policy, and plotting the learning dynamics

In [None]:

def evaluate_policy(env, Q, episodes=50, max_steps=200):
    """Roll out the greedy policy derived from ``Q`` and measure its performance.

    Args:
        env: environment exposing ``reset()``, ``step(action)``,
            ``state_to_idx(pos)`` and ``goal``.
        Q: table indexed as ``Q[state]`` -> action values.
        episodes: number of evaluation episodes.
        max_steps: maximum number of steps per episode.

    Returns:
        ``(success_rate, mean_total_reward, total_rewards)`` where an episode
        counts as a success only if it terminates on ``env.goal`` and
        ``total_rewards`` lists the undiscounted return of every episode.
    """
    success = 0
    total_rewards = []
    for _episode in range(episodes):
        s = env.state_to_idx(env.reset())
        total = 0.0
        for _step in range(max_steps):
            # Evaluation is purely greedy: no exploration.
            a = int(np.argmax(Q[s]))
            next_pos, r, done, _info = env.step(a)
            total += r
            s = env.state_to_idx(next_pos)
            if done:
                if next_pos == env.goal:
                    success += 1
                break
        total_rewards.append(total)
    return success/episodes, np.mean(total_rewards), total_rewards

def _smooth_rewards(history, window):
    """Moving average of ``history``; the window is clamped to the history length."""
    values = np.asarray(history, dtype=float)
    if values.size == 0:
        return values
    k = max(1, min(int(window), values.size))
    return np.convolve(values, np.ones(k) / k, mode='valid')


def plot_rewards(hist_q, hist_s, title='Training rewards', window=10):
    """Plot moving-average learning curves for Q-learning and SARSA.

    Args:
        hist_q: per-episode total rewards from Q-learning.
        hist_s: per-episode total rewards from SARSA.
        title: figure title.
        window: number of episodes in the moving average. It is clamped to
            the length of each history, so histories shorter than the window
            are averaged over their full length and an empty history plots
            nothing instead of raising.
    """
    plt.figure(figsize=(10,4))
    plt.plot(_smooth_rewards(hist_q, window), label='Q-learning (smoothed)')
    plt.plot(_smooth_rewards(hist_s, window), label='SARSA (smoothed)')
    plt.xlabel('Episode (smoothed window)')
    plt.ylabel('Total reward per episode')
    plt.legend()
    plt.title(title)
    plt.grid(True)
    plt.show()

def run_experiment(seed=0, width=8, height=6, wall_prob=0.15, lava_prob=0.05,
                   per_step_penalty=0.0, lava_terminal=True, episodes=500, random_initial=True, epsilon=0.1):
    """Train Q-learning and SARSA on one GridWorld and compare them.

    Builds a ``GridWorld`` from the given parameters (``seed`` is passed as
    its ``random_seed``), prints the layout, trains Q-learning and then SARSA
    on that same environment instance (alpha=0.6, gamma=0.99), evaluates each
    learned table over 200 episodes, prints both success rates and average
    rewards, and plots the two learning curves.

    Returns:
        ``(env, Q_q, Q_s, hist_q, hist_s)``: the environment, the Q-learning
        and SARSA tables, and their per-episode reward histories.
    """
    env = GridWorld(
        width=width,
        height=height,
        wall_prob=wall_prob,
        lava_prob=lava_prob,
        random_seed=seed,
        per_step_penalty=per_step_penalty,
        lava_terminal=lava_terminal,
        random_initial=random_initial,
    )
    print('Grid (A=Agent, G=goal, #=wall, L=lava):')
    env.render()

    # Both learners share the same hyper-parameters.
    training = dict(num_episodes=episodes, alpha=0.6, gamma=0.99, epsilon=epsilon)

    Q_q, hist_q = q_learning(env, **training)
    env.reset()
    Q_s, hist_s = sarsa(env, **training)
    env.reset()

    succ_q, mean_r_q = evaluate_policy(env, Q_q, episodes=200)[:2]
    env.reset()
    succ_s, mean_r_s = evaluate_policy(env, Q_s, episodes=200)[:2]

    print(f'Q-learning success rate: {succ_q:.2f}, avg reward: {mean_r_q:.2f}')
    print(f'SARSA       success rate: {succ_s:.2f}, avg reward: {mean_r_s:.2f}')

    curve_title = f'Learning curves | lava_prob={lava_prob} | per_step_penalty={per_step_penalty}'
    plot_rewards(hist_q, hist_s, title=curve_title)
    return env, Q_q, Q_s, hist_q, hist_s


# Plotting heatmaps for Q values and plotting the optimal policies

In [None]:
# Q-value plotting utilities
def plot_q_heatmaps(Q, env, title_prefix='Q-values'):
    """Show per-action Q-value heatmaps around a central state-value heatmap.

    The figure is a 3x3 grid of axes used as a cross: the centre panel shows
    ``max_a Q(s, a)`` and up to four action panels are placed around it. The
    action panels share one colour scale (min/max over all actions); the
    centre panel uses its own. Wall, lava, start and goal cells are labelled
    '#', 'L', 'S' and 'G' on every panel.

    Args:
        Q: array of shape ``(env.height * env.width, number of actions)``.
        env: environment providing ``height``, ``width``, ``grid`` and
            ``action_names``.
        title_prefix: text prepended to each panel title.
    """
    n_actions = Q.shape[1]
    per_action = [Q[:, k].reshape(env.height, env.width) for k in range(n_actions)]
    vmax = max(np.max(panel) for panel in per_action)
    vmin = min(np.min(panel) for panel in per_action)

    cell_labels = {1: '#', 2: 'L', 3: 'S', 4: 'G'}

    def annotate(ax):
        # Overlay the special-cell letters on top of the heatmap.
        for r in range(env.height):
            for c in range(env.width):
                label = cell_labels.get(env.grid[r, c])
                if label is not None:
                    ax.text(c, r, label, ha='center', va='center', fontsize=12, color='white')

    fig, axes = plt.subplots(3, 3, figsize=(20, 20))
    names = env.action_names

    # Start with every panel hidden; only the cross is switched back on.
    for ax in axes.flat:
        ax.set_visible(False)

    # Panel slot per action index. With the (R, D, L, U) action order each
    # panel sits on the side of the centre that the action moves toward:
    # right (1,2), down (2,1), left (1,0), up (0,1).
    slots = [(1, 2), (2, 1), (1, 0), (0, 1)]
    for k, (row, col) in enumerate(slots[:n_actions]):
        panel = axes[row, col]
        panel.set_visible(True)
        image = panel.imshow(per_action[k], origin='upper', vmin=vmin, vmax=vmax)
        panel.set_title(f'{title_prefix} — action {names[k]}')
        annotate(panel)
        fig.colorbar(image, ax=panel, fraction=0.046, pad=0.04)

    # Centre panel: greedy state value.
    center_ax = axes[1, 1]
    center_ax.set_visible(True)
    state_vals = np.max(Q, axis=1).reshape(env.height, env.width)
    image = center_ax.imshow(state_vals, origin='upper')
    center_ax.set_title(f'{title_prefix} — state-value (max Q)')
    annotate(center_ax)
    fig.colorbar(image, ax=center_ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()


def plot_greedy_policy(Q, env, title='Greedy policy (from Q)'):
    """Draw the greedy action of every empty cell as an arrow on the grid.

    Wall, lava, goal and start cells are drawn as the letters '#', 'L', 'G'
    and 'S' instead of an arrow. The y-axis is inverted so that row 0 is at
    the top.

    Args:
        Q: array of shape ``(env.height * env.width, number of actions)``.
        env: environment providing ``height``, ``width`` and ``grid``.
        title: figure title.
    """
    greedy = np.argmax(Q, axis=1).reshape(env.height, env.width)
    plt.figure(figsize=(env.width, env.height/1.5))

    markers = {1: '#', 2: 'L', 4: 'G', 3: 'S'}
    # (dx, dy) of the arrow per action index; x is the column, y is the row.
    arrow_delta = {0: (0.3, 0), 1: (0, 0.3), 2: (-0.3, 0), 3: (0, -0.3)}

    for r, c in np.ndindex(env.height, env.width):
        marker = markers.get(env.grid[r, c])
        if marker is not None:
            plt.text(c, r, marker, ha='center', va='center', fontsize=12)
            continue
        dx, dy = arrow_delta.get(greedy[r, c], (0, 0))
        plt.arrow(c, r, dx, dy, head_width=0.12, head_length=0.12)

    plt.xlim(-0.5, env.width-0.5)
    plt.ylim(env.height-0.5, -0.5)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.title(title)
    plt.axis('off')
    plt.show()


# The effect of reward shaping and having terminal non-goal states (lava)

In [None]:
# Fixed initial state: every run below is created with random_initial=False.
random_initial = False

# Demo runs - modify parameters or call run_experiment yourself.
# a: no lava, no extra step penalty
env_a, Qqa, Qsa, hqa, hsa = run_experiment(
    seed=10, lava_prob=0.0, per_step_penalty=0.0,
    episodes=100000, random_initial=random_initial,
)
# b: no lava, step penalty of -0.5
env_b, Qqb, Qsb, hqb, hsb = run_experiment(
    seed=10, lava_prob=0.0, per_step_penalty=-0.5, lava_terminal=True,
    episodes=100000, random_initial=random_initial,
)
# c: terminal lava, no extra step penalty
env_c, Qqc, Qsc, hqc, hsc = run_experiment(
    seed=13, lava_prob=0.08, per_step_penalty=0.0, lava_terminal=True,
    episodes=100000, random_initial=random_initial,
)
# d: terminal lava, step penalty of -0.5
env_d, Qqd, Qsd, hqd, hsd = run_experiment(
    seed=13, lava_prob=0.08, per_step_penalty=-0.5, lava_terminal=True,
    episodes=100000, random_initial=random_initial,
)

## Q-value visualization utilities
This section uses the plotting utilities defined above to visualize the learned Q-values:
- `plot_q_heatmaps(Q, env, title_prefix='')` — per-action Q-value heatmaps arranged around a central state-value (max Q) heatmap
- `plot_greedy_policy(Q, env, title='')` — draws the best-action arrow in every empty cell of the grid

The demo cell below visualizes the Q-tables trained in the previous cell (`Qqa`, `Qsa`, `Qqb`, `Qsb`, `Qqc`, `Qsc`, `Qqd`, `Qsd`), so that cell must be run first.

In [None]:
# Demo: try to find an existing trained Q (common names used in this notebook are Qqa, Qqb, Qqc, Q_q, Q_s, etc.)
runs = [(env_a,Qqa,'No Lava-No Penalty, Q Learning'), (env_a,Qsa, 'No Lava-No Penalty, SARSA'), (env_b,Qqb, 'No Lava-With Penalty, Q Learning'),
(env_b,Qsb, 'No Lava-With Penalty, SARSA'), (env_c,Qqc, 'With Lava-No Penalty, Q Learning'), (env_c,Qsc, 'With Lava-No Penalty, SARSA'),
(env_d,Qqd, 'With Lava-No Penalty, Q Learning'), (env_d,Qsd, 'With Lava-With Penalty, SARSA')]

found = None
for run in runs:
    
    env_var = run[0]
    print(run[2])
    env_var.reset()
    env_var.render()
    Q = run[1]
    plot_q_heatmaps(Q, env_var, title_prefix=f'{run[2]}')
    plot_greedy_policy(Q, env_var, title=f'Policy from {run[2]}')


# The effect of randomizing the initial state, and comparing off-policy and on-policy algorithms

In [None]:
# Random initial state: every run below is created with random_initial=True.
random_initial = True
# Demo runs - modify parameters or call run_experiment yourself
env_a, Qqa, Qsa, hqa, hsa = run_experiment(seed=10, lava_prob=0.0, per_step_penalty=0.0, episodes=100000, random_initial=random_initial)
env_b, Qqb, Qsb, hqb, hsb = run_experiment(seed=10, lava_prob=0.0, per_step_penalty=-0.5, lava_terminal=True, episodes=100000, random_initial=random_initial)
env_c, Qqc, Qsc, hqc, hsc = run_experiment(seed=13, lava_prob=0.08, per_step_penalty=0.0, lava_terminal=True, episodes=100000, random_initial=random_initial)
env_d, Qqd, Qsd, hqd, hsd = run_experiment(seed=13, lava_prob=0.08, per_step_penalty=-0.5, lava_terminal=True, episodes=100000, random_initial=random_initial)

# Visualize the Q-tables trained above.
# Each entry is (environment, Q-table, label used in printed output and plot titles).
runs = [
    (env_a, Qqa, 'No Lava-No Penalty, Q Learning'),
    (env_a, Qsa, 'No Lava-No Penalty, SARSA'),
    (env_b, Qqb, 'No Lava-With Penalty, Q Learning'),
    (env_b, Qsb, 'No Lava-With Penalty, SARSA'),
    (env_c, Qqc, 'With Lava-No Penalty, Q Learning'),
    (env_c, Qsc, 'With Lava-No Penalty, SARSA'),
    # env_d was trained with per_step_penalty=-0.5, so both of its labels say "With Penalty".
    (env_d, Qqd, 'With Lava-With Penalty, Q Learning'),
    (env_d, Qsd, 'With Lava-With Penalty, SARSA'),
]

for env_var, Q, label in runs:
    print(label)
    env_var.reset()
    env_var.render()
    plot_q_heatmaps(Q, env_var, title_prefix=label)
    plot_greedy_policy(Q, env_var, title=f'Policy from {label}')


# The effect of exploration

In [None]:
# Same 30x30 lava grid trained with four exploration rates (epsilon).
random_initial = True
# Demo runs - modify parameters or call run_experiment yourself
w = 30
h = 30
env_a, Qqa, Qsa, hqa, hsa = run_experiment(width=w, height=h, seed=13, lava_prob=0.08, per_step_penalty=-0.8, episodes=10000, random_initial=random_initial, epsilon=0.0)
env_b, Qqb, Qsb, hqb, hsb = run_experiment(width=w, height=h, seed=13, lava_prob=0.08, per_step_penalty=-0.8, lava_terminal=True, episodes=10000, random_initial=random_initial, epsilon=0.1)
env_c, Qqc, Qsc, hqc, hsc = run_experiment(width=w, height=h, seed=13, lava_prob=0.08, per_step_penalty=-0.8, lava_terminal=True, episodes=10000, random_initial=random_initial, epsilon=0.3)
env_d, Qqd, Qsd, hqd, hsd = run_experiment(width=w, height=h, seed=13, lava_prob=0.08, per_step_penalty=-0.8, lava_terminal=True, episodes=10000, random_initial=random_initial, epsilon=0.7)

# Visualize the Q-tables trained above.
# Each entry is (environment, Q-table, label); the labels state the epsilon
# actually passed to run_experiment (the first run uses epsilon=0.0).
runs = [
    (env_a, Qqa, 'epsilon 0.0, Q Learning'),
    (env_a, Qsa, 'epsilon 0.0, SARSA'),
    (env_b, Qqb, 'epsilon 0.1, Q Learning'),
    (env_b, Qsb, 'epsilon 0.1, SARSA'),
    (env_c, Qqc, 'epsilon 0.3, Q Learning'),
    (env_c, Qsc, 'epsilon 0.3, SARSA'),
    (env_d, Qqd, 'epsilon 0.7, Q Learning'),
    (env_d, Qsd, 'epsilon 0.7, SARSA'),
]

for env_var, Q, label in runs:
    print(label)
    env_var.reset()
    env_var.render()
    plot_q_heatmaps(Q, env_var, title_prefix=label)
    plot_greedy_policy(Q, env_var, title=f'Policy from {label}')
